# @package _global_
# Base experiment config that this file builds on. The keys defined further
# down in this file are meant to override the matching keys from that base
# (standard Hydra defaults-list composition).
defaults:
  - /experiment/pile/gpt3xl-flash-8k.yaml

# Model size settings for this experiment.
model:
  config:
    # Embedding width, attention head count, and layer count.
    # Assuming the width is split evenly across heads, 2560 / 32 gives
    # 80 dimensions per head.
    n_embd: 2560
    n_head: 32
    n_layer: 32
    # Evaluates to sqrt(2 / (5 * n_embd)); with n_embd = 2560 that is 0.0125.
    # `.n_embd` inside the expression is a relative interpolation to the
    # sibling key above, so the value tracks any change to n_embd.
    # NOTE(review): `eval` is not an OmegaConf built-in resolver; it must be
    # registered by the project before this config is resolved — confirm.
    initializer_range: ${eval:"(2 / (${.n_embd} * 5)) ** 0.5"}
    # NOTE(review): presumably the MLP activation-checkpointing level, with 0
    # meaning no checkpointing — confirm against the model code.
    mlp_checkpoint_lvl: 0

# Batch size chosen from the value of train.gpu_mem:
#   gpu_mem < 40        -> 1
#   40 <= gpu_mem < 80  -> 2
#   gpu_mem >= 80       -> 4
# train.gpu_mem is not set in this file, so it has to come from another
# config or the command line.
# NOTE(review): gpu_mem is presumably in GB, and this is presumably the
# per-device batch size — confirm both against where they are consumed.
datamodule:
  batch_size: ${eval:"1 if ${train.gpu_mem} < 40 else (2 if ${train.gpu_mem} < 80 else 4)"}

# Optimizer learning rate for this model size.
# NOTE(review): 1.6e-4, together with 32 layers / width 2560 / 32 heads,
# appears to follow the 2.7B row of the GPT-3 paper's model table — confirm
# if the provenance matters.
train:
  optimizer:
    lr: 1.6e-4
